use std::{
    pin::Pin,
    task::{Context, Poll},
};

use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use pyo3::{
    IntoPyObjectExt,
    exceptions::PyTypeError,
    intern,
    prelude::*,
    pybacked::{PyBackedBytes, PyBackedStr},
};
use tokio::sync::mpsc::{self, error::TryRecvError};

use crate::{buffer::PyBuffer, client::future::PyFuture, error::Error, rt::Runtime};

/// Represents a Python streaming body, either synchronous or asynchronous.
pub enum PyStream {
    Sync(Py<PyAny>),
    Async(Pin<Box<dyn Stream<Item = Py<PyAny>> + Send + 'static>>),
}

/// A byte stream response.
/// An asynchronous iterator yielding data chunks from the response stream.
/// Used to stream response content.
/// Implemented in the `stream` method of the `Response` class.
/// Can be used in an asynchronous for loop in Python.
#[pyclass(subclass)]
pub struct Streamer(mpsc::Receiver<wreq::Result<Bytes>>);

// ===== impl Streamer =====

impl Streamer {
    /// Create a new [`Streamer`] instance.
    pub fn new(stream: impl Stream<Item = wreq::Result<Bytes>> + Send + 'static) -> Streamer {
        let (tx, rx) = mpsc::channel(8);
        Runtime::spawn(async move {
            futures_util::pin_mut!(stream);
            while let Some(item) = stream.next().await {
                if tx.send(item).await.is_err() {
                    break;
                }
            }
        });

        Streamer(rx)
    }
}

/// Asynchronous iterator implementation for `Streamer`.
#[pymethods]
impl Streamer {
    #[inline]
    fn __aiter__(slf: PyRef<Self>) -> PyRef<Self> {
        slf
    }

    fn __anext__<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let res = py.detach(|| match self.0.try_recv() {
            Ok(res) => res.map(PyBuffer::from).map(Some).map_err(Error::Library),
            Err(err) => match err {
                TryRecvError::Empty => Ok(None),
                TryRecvError::Disconnected => Err(Error::StopAsyncIteration),
            },
        })?;
        PyFuture::closure(py, move || Ok(res))
    }

    #[inline]
    fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let slf = slf.into_py_any(py)?;
        PyFuture::closure(py, move || Ok(slf))
    }

    #[inline]
    fn __aexit__<'py>(
        &mut self,
        py: Python<'py>,
        _exc_type: &Bound<'py, PyAny>,
        _exc_value: &Bound<'py, PyAny>,
        _traceback: &Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        self.0.close();
        PyFuture::closure(py, move || Ok(()))
    }
}

/// Synchronous iterator implementation for `Streamer`.
#[pymethods]
impl Streamer {
    #[inline]
    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
        slf
    }

    fn __next__(&mut self, py: Python) -> PyResult<PyBuffer> {
        py.detach(|| {
            self.0
                .blocking_recv()
                .ok_or(Error::StopIteration)?
                .map(PyBuffer::from)
                .map_err(Error::Library)
                .map_err(Into::into)
        })
    }

    #[inline]
    fn __enter__(slf: PyRef<Self>) -> PyRef<Self> {
        slf
    }

    #[inline]
    fn __exit__<'py>(
        &mut self,
        _exc_type: &Bound<'py, PyAny>,
        _exc_value: &Bound<'py, PyAny>,
        _traceback: &Bound<'py, PyAny>,
    ) {
        self.0.close();
    }
}

// ===== impl PyStream =====

impl FromPyObject<'_, '_> for PyStream {
    type Error = PyErr;

    /// Extracts a [`PyStream`] from a Python object.
    /// Accepts sync or async iterators.
    fn extract(ob: Borrowed<PyAny>) -> PyResult<Self> {
        if ob.hasattr(intern!(ob.py(), "asend"))? {
            crate::rt::Runtime::into_stream(ob)
                .map(Box::pin)
                .map(|stream| PyStream::Async(stream))
        } else {
            ob.extract::<Py<PyAny>>()
                .map(PyStream::Sync)
                .map_err(Into::into)
        }
    }
}

impl Stream for PyStream {
    type Item = PyResult<Bytes>;

    /// Yields the next chunk from the Python stream as bytes.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut() {
            PyStream::Sync(iter) => Python::attach(|py| {
                let next = iter
                    .call_method0(py, intern!(py, "__next__"))
                    .ok()
                    .map(|item| extract_bytes(py, item));
                py.detach(|| Poll::Ready(next))
            }),
            PyStream::Async(stream) => {
                let waker = cx.waker();
                Python::attach(|py| {
                    py.detach(|| stream.as_mut().poll_next(&mut Context::from_waker(waker)))
                        .map(|item| item.map(|item| extract_bytes(py, item)))
                })
            }
        }
    }
}

/// Extracts a [`Bytes`] object from a Python object.
/// Accepts bytes-like or str-like objects, otherwise raises a `TypeError`.
#[inline]
fn extract_bytes(py: Python<'_>, ob: Py<PyAny>) -> PyResult<Bytes> {
    match ob.extract::<PyBackedBytes>(py) {
        Ok(chunk) => Ok(Bytes::from_owner(chunk)),
        Err(_) => ob
            .extract::<PyBackedStr>(py)
            .map(Bytes::from_owner)
            .map_err(|err| {
                PyTypeError::new_err(format!("Stream must yield bytes/str - like objects: {err}"))
            }),
    }
}
